import numpy as np
import yaml
import pandas as pd
#import h5py
import os
import pickle
import pyodbc
import json
from json import JSONEncoder
import logging
import time
from sklearn.preprocessing import RobustScaler, MaxAbsScaler, StandardScaler
from scipy.stats.mstats import winsorize
from pandas.api.types import CategoricalDtype
import hashlib


class Winsorizer:
    """Sklearn-style scaler that winsorizes and optionally standardizes.

    Data is first clipped to ``[min_val, max_val]`` (when at least one bound
    is given) and then passed through ``scipy.stats.mstats.winsorize`` with
    the extra keyword arguments given to the constructor.  When
    ``standardize`` is true the winsorized data is additionally scaled with a
    ``StandardScaler``; otherwise the ratio of the pre- to post-winsorization
    mean is stored so ``inverse_transform`` can optionally undo the mean shift.

    Args:
        standardize (bool): standardize the winsorized data.
        min_val (float or None): lower clipping bound, None for unbounded.
        max_val (float or None): upper clipping bound, None for unbounded.
        mean_adj (bool): when not standardizing, multiply by the stored mean
            ratio in ``inverse_transform``.
        **kwargs: forwarded to ``scipy.stats.mstats.winsorize``.
    """

    def __init__(self, standardize=True, min_val=None, max_val=None,
                 mean_adj=False, **kwargs):
        self.kwargs = kwargs
        self.standardize = standardize
        self.min_val = min_val
        self.max_val = max_val
        self.mean_adj = mean_adj

    def _clip(self, X):
        """Clip X to the configured bounds; a no-op when neither bound is set."""
        # np.clip rejects (None, None) on NumPy releases that require a bound.
        if self.min_val is None and self.max_val is None:
            return X
        return np.clip(X, self.min_val, self.max_val)

    def _apply(self, X, refit):
        """Clip and winsorize X, (re)estimating ``self.scaler`` when refit is true."""
        X = self._clip(X)
        transformed = winsorize(X, **self.kwargs)
        if self.standardize:
            if refit:
                self.scaler = StandardScaler().fit(transformed)
            return self.scaler.transform(transformed)
        if refit:
            # ratio used by inverse_transform to restore the original mean
            self.scaler = np.mean(X) / np.mean(transformed)
        return transformed

    def fit(self, X):
        """Estimate the scaling state from X.

        Args:
            X (pd.DataFrame, ndarray): DataFrame or array (rows x features)

        Returns:
            self
        """
        self._apply(X, refit=True)
        return self

    def fit_transform(self, X):
        """Fit on X and return the transformed X."""
        return self._apply(X, refit=True)

    def transform(self, X):
        """Transform X using the state estimated by ``fit``.

        A fitted (or unpickled) instance keeps its scaler; an instance that
        was never fitted is fitted on X first, as before.  Winsorization
        itself is always applied to the data passed in.
        """
        return self._apply(X, refit=not hasattr(self, 'scaler'))

    def inverse_transform(self, X):
        """Map transformed values back to the original scale.

        Args:
            X (array-like): transformed values; flattened before use.

        Returns:
            ndarray: 1-D array of reconstructed values, clipped to the
            configured bounds.
        """
        flat = np.ravel(X)
        if self.standardize:
            # the scaler expects a 2-D (samples x 1) array for a single feature
            recon = np.ravel(self.scaler.inverse_transform(flat.reshape(-1, 1)))
        elif self.mean_adj:
            recon = flat * self.scaler
        else:
            recon = flat
        return self._clip(recon)
       

class Digitizer:
    """Bin values into integer classes and map classes back to bin means.

    Class labels are ``np.digitize(X, bins) - 1``.  ``class_means`` holds one
    entry per interval between consecutive bin edges and is initialised to
    ones until ``fit`` is called.
    """

    def __init__(self, bins):
        n_classes = len(bins) - 1
        self.bins = bins
        self.class_means = np.ones(n_classes)
        self.classes_ = np.arange(n_classes)

    def fit(self, X):
        """Record the NaN-ignoring mean of X within each class."""
        labels = np.digitize(X, self.bins) - 1
        for label in self.classes_:
            members = X[labels == label]
            self.class_means[label] = np.nanmean(members)
        return self

    def fit_transform(self, X):
        """Fit on X, then return its class labels."""
        self.fit(X)
        return self.transform(X)

    def transform(self, X):
        """Return the class label of every value in X."""
        bin_index = np.digitize(X, self.bins)
        return bin_index - 1

    def inverse_transform(self, X):
        """Replace each class label in X (flattened) by its fitted class mean."""
        labels = np.ravel(X)
        means = [self.class_means[label] for label in labels]
        return np.array(means)

        
def restrict_types(df, allowed_types=('int64', 'float64')):
    """Remove columns not corresponding to acceptable types.

    A column whose dtype is not allowed is cast to the first allowed type that
    accepts its values (e.g. numeric strings become numbers).  Columns that
    cannot be cast to any allowed type are dropped.  The input frame is left
    untouched.

    Args:
        df (pd.DataFrame): DataFrame to reduce
        allowed_types (sequence of str): acceptable data types, tried in order

    Returns:
        pd.DataFrame minus unwanted columns.
    """
    allowed = list(allowed_types)
    converted = {}
    drop = []
    for col in df.columns:
        t = str(df[col].dtype)
        if t in allowed:
            continue
        # convert potential string descriptions to real types
        for allowed_type in allowed:
            try:
                converted[col] = df[col].astype(allowed_type)
            except (ValueError, TypeError):
                # this type does not fit the column; try the next one
                continue
            logging.debug(f'Replaced type of {col} with {allowed_type}.')
            break
        else:
            logging.debug(f'Dropping column {col} with disallowed type {t}.')
            drop.append(col)

    result = df.drop(columns=drop, inplace=False)
    for col, values in converted.items():
        result[col] = values
    return result


def expand_categorical(df, cat_dict):
    """Replace specified categorical columns with onehots.

    Args:
        df (pd.DataFrame): DataFrame to alter
        cat_dict (dict): maps each column to expand to the ordered list of
            categories it may contain; one indicator column is produced per
            listed category, named ``<column>_<category>``.

    Returns:
        pd.DataFrame

    Raises:
        AssertionError: if a column holds a value outside its specification.
    """
    for col, categories in cat_dict.items():
        found = np.unique(df[col])
        if not all(unq in categories for unq in found):
            msg = ('Loaded data contains categorical features '
                   f"outside of specification for {col}: {', '.join(found.astype(str))}")
            logging.critical(msg)
            # raised explicitly so the check survives `python -O` and carries
            # the message (logging.critical returns None)
            raise AssertionError(msg)
        dt = CategoricalDtype(categories, ordered=True)
        expanded = pd.get_dummies(df[col].astype(dt), prefix=col)
        logging.debug(f'Expanded features with categorical columns {expanded.columns.values}'
                      f' of type {dt}')
        df = df.join(expanded, how='outer')

    return df.drop(columns=list(cat_dict), inplace=False)

def collapse_categorical(df, cols):
    """Overwrite the given columns with presence indicators.

    Each listed column is replaced in place by 1 where it held a value and 0
    where it was null.

    Args:
        df (pd.DataFrame): DataFrame to alter (modified in place)
        cols: Columns to collapse

    Returns:
        pd.DataFrame: the same ``df`` object
    """
    for name in cols:
        has_value = df[name].notnull().to_numpy()
        df[name] = has_value.astype('int64')
    return df


def drop_columns(df, cols):
    """Return a copy of df without the listed columns.

    Names in ``cols`` that are not columns of ``df`` are ignored.
    """
    present = np.intersect1d(cols, df.columns)
    return df.drop(columns=present)


def csv_to_h5(filepath, out=None, dtypes=None, chunksize=10000,
              min_itemsize=50, complevel=0):
    """Convert a csv file to a h5.  Simplifies access and allows for
    potentially higher i/o.  Will automatically append if file exists.

    The csv is streamed once in chunks of ``chunksize`` rows and written to
    the table ``data``.  Rows are indexed 0..N-1 in file order.

    Args:
        filepath (string): path to csv
        out (string): path to h5.  Defaults to ``filepath`` with its final
            extension replaced by ``.h5``.
        dtypes (dict): dictionary matching features to types.  If none, will infer
            from csv load (using at most the first 2,000,000 rows).
        chunksize (int): maximum number of rows to process at a time.
        min_itemsize (int): padded length of item string default (effectively
            maximum string size).
        complevel (int): compression level to use in h5 storage [0-9].

    Returns:
        dict: the dtypes used for the conversion (given or inferred).
    """
    if out is None:
        # splitext strips only the final extension, so dots elsewhere in the
        # path (e.g. './data/file.csv') do not truncate the output name
        out = os.path.splitext(filepath)[0] + '.h5'

    if dtypes is None:
        sample = pd.read_csv(filepath, low_memory=False, nrows=2000000)
        dtypes = dict(sample.dtypes)
        del sample

    # A single chunked pass avoids re-reading the file for every chunk.
    reader = pd.read_csv(filepath, dtype=dtypes, low_memory=False, header=0,
                         engine='c', chunksize=chunksize)
    start = 0
    for data in reader:
        stop = start + len(data)
        data.index = np.arange(start, stop)
        if start == 0:
            try:
                # mode='r+' only succeeds when the h5 already exists
                data.to_hdf(out, key='data', mode='r+', format='table', index=True,
                            min_itemsize=min_itemsize, complevel=complevel, append=True,
                            data_columns=True)
                logging.debug(f'Appending rows {start}:{stop} from file {filepath} to h5.')
            except OSError:
                logging.debug(f'Starting new h5 with rows {start}:{stop} from file {filepath}.')
                data.to_hdf(out, key='data', mode='a', format='table', index=True,
                            min_itemsize=min_itemsize, complevel=complevel)
        else:
            logging.debug(f'Appending rows {start}:{stop} from file {filepath} to h5.')
            data.to_hdf(out, key='data', mode='a', format='table', index=True,
                        append=True, complevel=complevel)

        start = stop

    return dtypes


def lower_config(config):
    """Transform a data config dictionary to have all lower case field names.

    Only the fields that hold feature names are touched: a string is
    lowered, a list has every entry lowered, and a dict has its keys lowered
    (values are kept as they are).  Fields that are absent or hold any other
    type (e.g. None) are left alone.  The dictionary is modified in place.

    Args:
        config (dict): config dict meant for LoadedInterface

    Returns:
        config (dict)
    """
    logging.debug('Lowering case of all fields.')
    for field in ('reward', 'target', 'sample_weight', 'fit_weight',
                  'categorical_features', 'drop_features', 'meta_features'):
        if field not in config:
            continue
        value = config[field]
        # isinstance also covers subclasses such as OrderedDict
        if isinstance(value, str):
            config[field] = value.lower()
        elif isinstance(value, list):
            config[field] = [key.lower() for key in value]
        elif isinstance(value, dict):
            config[field] = {key.lower(): item for key, item in value.items()}

    return config
        
        
class LoadedInterface:
    """Provides in-memory data interface to either files or database.

    The constructor loads a table, restricts it to the configured columns,
    applies the configured clean-up steps and splits the result into
    ``features``, ``target``, ``reward``, ``meta``, ``sample_weight`` and
    ``fit_weight``.  ``get`` then returns the rows matching a query on the
    meta columns.
    """

    def __init__(self, config, rows=None, query=None, load_scalers=False, lower_case=True):
        """Load and prepare the data described by ``config``.

        Args:
            config (dict or str): data config, or path to a yaml file holding it.
                Note that the config dict is modified in place (defaults are
                filled in and target/reward are appended to ``drop_features``).
            rows: optional row selection handed to the csv/h5 loader.
            query: not used by the file-loading paths in this class body.
            load_scalers (bool): load pickled scalers from their ``save_path``
                instead of fitting (and pickling) new ones.
            lower_case (bool): lower-case config field names, column names and
                queries.

        Raises:
            ValueError: if ``raw_file`` is neither a ``.csv`` nor a ``.h5`` file
                and the database is not queried.
        """
        self.query_database = False
        self.config = config
        self.load_scalers = load_scalers
        # Always defined so get() can consult it even when lowering is disabled.
        self.lower_case = lower_case

        self.features = None
        self.target = None
        self.reward = None
        self.meta = None
        self.sample_weight = None
        self.fit_weight = None

        self.target_scaler = None
        self.reward_scaler = None
        self.feature_scaler = None

        if isinstance(config, str):
            with open(config, 'r') as f:
                config = yaml.safe_load(f)

        if lower_case:
            config = lower_config(config)

        # NOTE(review): the database branch body is redacted here; as shown it
        # does not define `df`, which the code below requires.
        if config['query_database']:
            '''
		REDACTED
		'''
        elif config['raw_file'].endswith('.csv'):
            logging.debug('Loading data from csv.')
            if rows is None:
                df = pd.read_csv(config['raw_file'], low_memory=False)
            else:
                # NOTE(review): the callable also sees the header line (index 0),
                # which is skipped unless 0 is in rows -- confirm with callers.
                df = pd.read_csv(config['raw_file'], low_memory=False, skiprows=lambda x: x not in rows)
        elif config['raw_file'].endswith('.h5'):
            logging.debug('Loading data from h5.')
            hdf = pd.HDFStore(config['raw_file'])
            try:
                df = hdf.select('data') if rows is None else hdf.select('data', rows)
            finally:
                # release the file handle even when the selection fails
                hdf.close()
        else:
            raise ValueError(f"Unsupported raw_file type: {config['raw_file']}")

        if lower_case:
            df.columns = df.columns.str.lower()

        # todo add this to the loading mechanism(s) in a way that it can robustly handle case
        if config['use_dep_database']:
            full_features = config['features'] + config['dep_database_features'] + config['meta_features']
        else:
            full_features = config['features'] + config['meta_features']
        for field in ['target', 'reward', 'sample_weight', 'fit_weight']:
            if config[field] is not None:
                full_features.append(config[field])
        full_features = np.unique(full_features)

        # subset to what exists
        df = df[np.intersect1d(full_features, df.columns.values)]
        # add what doesn't as nan
        for col in np.setdiff1d(full_features, df.columns.values):
            logging.debug(f'Adding in missing feature with nans: {col}')
            df[col] = np.nan

        # features that may have multiple appearances under alias -- fill in gaps in both
        if 'replicate_features' in config:
            for col in config['replicate_features']:
                sister_col = config['replicate_features'][col]
                logging.debug(f"Filling in gaps using features {col} and {sister_col}")
                df.loc[df[col].isna(), col] = df.loc[df[col].isna(), sister_col]
                df.loc[df[sister_col].isna(), sister_col] = df.loc[df[sister_col].isna(), col]

        if 'rename_features' in config:
            logging.debug(f"Renaming features according to dict: {config['rename_features']}")
            df.rename(columns=config['rename_features'], inplace=True)

        if 'fill_na' in config:

            specific_reps = {key: config['fill_na'][key] for key in config['fill_na'] if key != 'all'}
            logging.debug(f"Filling NaNs according to dict: {specific_reps}")
            df.fillna(value=specific_reps, inplace=True)

            # always fill 'all' last
            if 'all' in config['fill_na']:
                logging.debug("Filling remaining NaNs with zero.")
                df.fillna(config['fill_na']['all'], inplace=True)

        if 'zfill' in config:
            for col in config['zfill']:
                logging.debug(f"Z-filling for {col}.")
                df[col] = df[col].apply(lambda x: str(x).zfill(config['zfill'][col]))

        if 'rules' in config:
            logging.debug('Collapsing multi-category rules into binary inds.')
            df = collapse_categorical(df, config['rules'])

        # todo add all this behind features in case anybody wants categoricals
        if config['meta_features'] is not None:
            self.meta = df.loc[:, config['meta_features']]
            if len(config['meta_features']) == 1:
                # pass the name list itself; wrapping it in another list would
                # hand pandas a list of lists as column labels
                self.meta = pd.DataFrame(self.meta.values[:, 0], index=df.index,
                                         columns=config['meta_features'])

        # default weights are sized from df so they do not depend on meta
        if config['sample_weight'] is None:
            self.sample_weight = pd.Series(np.ones(len(df)), index=df.index)
        else:
            logging.debug(f"Preparing to weight samples for pop est. by {config['sample_weight']}.")
            sample_weight = df.loc[:, config['sample_weight']]
            self.sample_weight = pd.Series(sample_weight, index=df.index)

        if config['fit_weight'] is None:
            self.fit_weight = pd.Series(np.ones(len(df)), index=df.index)
        else:
            logging.debug(f"Preparing to weight samples for fitting by {config['fit_weight']}.")
            fit_weight = df.loc[:, config['fit_weight']]
            self.fit_weight = pd.Series(fit_weight, index=df.index)

        if config['target'] is not None:
            if config['target'] not in config['drop_features']:
                logging.debug(f"Dropping specified target {config['target']} from feature set.")
                config['drop_features'] += [config['target']]

            self.target = df.loc[:, config['target']]

            config.setdefault('target_rescale', False)

            if config['target_rescale']:
                column = self.target.values.reshape(-1, 1)  # handles single-feature
                scaler = self._build_scaler(config['target_scaler'], column,
                                            load_scalers, 'target')
                self.target_scaler = scaler
                target = scaler.transform(column)[:, 0]  # handles single-feature
                self.target = pd.DataFrame(target, index=df.index, columns=[config['target']])

        if config['reward'] is not None:
            # ensure the reward is never in the feature space
            if config['reward'] not in config['drop_features']:
                logging.debug(f"Dropping specified reward {config['reward']} from feature set.")
                config['drop_features'] += [config['reward']]

            self.reward = df.loc[:, config['reward']]

            config.setdefault('reward_rescale', False)

            if config['reward_rescale']:
                column = self.reward.values.reshape(-1, 1)  # handles single-feature
                scaler = self._build_scaler(config['reward_scaler'], column,
                                            load_scalers, 'reward')
                self.reward_scaler = scaler
                reward = scaler.transform(column)[:, 0]  # handles single-feature
                self.reward = pd.DataFrame(reward, index=self.reward.index, columns=[config['reward']])

        # todo remove if, pretty uncommon now
        if config['drop_features'] != 'all':
            logging.debug(f"Dropping columns from feature set: {config['drop_features']}")
            df = drop_columns(df, config['drop_features'])
            ############ come back and revert this!!!
            #df = expand_categorical(df, config['categorical_features'])
            df = df.drop(columns=config['categorical_features'])
            df = restrict_types(df, config['allowed_types'])
            self.features = df

            if config['feature_rescale']:
                scaler = self._build_scaler(config['feature_scaler'], self.features,
                                            load_scalers, 'feature')
                self.feature_scaler = scaler
                self.features.loc[:] = scaler.transform(self.features)

        # features stay None when every column was dropped ('all')
        if self.features is not None:
            logging.debug("Sorting features alphabetically.")
            self.features = self.features.reindex(sorted(self.features.columns), axis=1)
            logging.debug(f"Final feature set includes: {self.features.columns.values}")
            logging.debug(f"Features have types: {np.unique([t for t in self.features.dtypes.values])}")

        self.data_components = [
            comp for comp in ['features', 'target', 'meta', 'reward', 'sample_weight', 'fit_weight']
            if getattr(self, comp) is not None
        ]

    @staticmethod
    def _build_scaler(spec, data, load_scalers, label):
        """Load the pickled scaler described by spec, or fit and pickle a new one.

        Args:
            spec (dict): scaler config with 'save_path' and, when fitting,
                'name' (a class visible in this module) and 'kwargs'.
            data: data handed to the scaler's ``fit``.
            load_scalers (bool): load from 'save_path' instead of fitting.
            label (str): which scaler this is, for log messages.

        Returns:
            The loaded or freshly fitted scaler.
        """
        if load_scalers:
            logging.debug(f"Loading {label} scaler from {spec}.")
            # NOTE: unpickling runs arbitrary code; only load trusted scaler files.
            with open(spec['save_path'], 'rb') as f:
                return pickle.load(f)

        logging.debug(f'Fitting {label} scaler.')
        scaler = globals()[spec['name']](**spec['kwargs'])
        scaler = scaler.fit(data)
        with open(spec['save_path'], 'wb') as f:
            pickle.dump(scaler, f)
        return scaler

    def get(self, query, as_array=True):
        """
        Return data that fulfills the given query.

        Args:
            query (str or list of str): Query statement(s) for pandas table,
                evaluated against the meta features.  A list is joined with
                'and'; None selects every row.
            as_array (bool): Converts output to ndarray (default True)

        Returns:
            list: the selected indices followed by one entry per available
            data component, in the order features, target, meta, reward,
            sample_weight, fit_weight (components that are None are omitted).
        """

        # query must be series of selection criteria
        if isinstance(query, list):
            query = ' and '.join(query)

        if query is None:
            inds = self.meta.index
        else:
            if self.lower_case:
                query = query.lower()
            inds = self.meta.query(query).index

        out = [inds]
        for comp in self.data_components:
            selected = getattr(self, comp).loc[inds]
            out.append(selected.values if as_array else selected)

        return out

    
def hash_file(filepath, blocksize=65536):
    """Return the hex MD5 digest of a file, read ``blocksize`` bytes at a time."""
    digest = hashlib.md5()
    with open(filepath, 'rb') as stream:
        # iter() with a sentinel stops at the first empty read (end of file)
        for block in iter(lambda: stream.read(blocksize), b''):
            digest.update(block)
    return digest.hexdigest()
